{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SimulatedQuantize\n",
    "\n",
    "源码：`tvm/src/relay/quantize/quantize.cc` 和 `tvm/python/tvm/relay/quantize/_annotate.py`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "::::{dropdown}\n",
    "```c++\n",
    "TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);\n",
    "\n",
    "bool SimulatedQuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,\n",
    "                          const TypeReporter& reporter) {\n",
    "  ICHECK_EQ(types.size(), 5);\n",
    "  const auto param = attrs.as<SimulatedQuantizeAttrs>();\n",
    "  ICHECK(param != nullptr);\n",
    "\n",
    "  const auto* data = types[0].as<TensorTypeNode>();\n",
    "\n",
    "  if (data == nullptr) {\n",
    "    return false;\n",
    "  }\n",
    "\n",
    "  ICHECK_NE(data->shape.size(), 0) << \"Input shape cannot be empty\";\n",
    "\n",
    "  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));  // dom_scale\n",
    "  reporter->Assign(types[2], TensorType({}, DataType::Float(32)));  // clip_min\n",
    "  reporter->Assign(types[3], TensorType({}, DataType::Float(32)));  // clip_max\n",
    "  reporter->Assign(types[4], types[0]);                             // output\n",
    "  return true;\n",
    "}\n",
    "```\n",
    "::::"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码定义了 `SimulatedQuantizeRel` 函数，它的作用是检查输入的类型是否符合预期。具体来说，它首先检查输入的类型数量是否为 5，然后从属性中获取 `SimulatedQuantizeAttrs` 类型的参数。接着，它检查第一个类型是否为 `TensorTypeNode` 类型，如果不是则返回 `false`。最后，它将输出的类型分别设置为 `dom_scale`、`clip_min`、`clip_max` 和输入数据的类型。"
   ]
  },
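  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The checks described above can be mirrored in a few lines of plain Python. This is only a toy model under stated assumptions, not TVM's API: `TensorType` and `simulated_quantize_rel` are hypothetical names, and `None` stands in for a type that has not been inferred yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "\n",
    "\n",
    "@dataclass(frozen=True)\n",
    "class TensorType:\n",
    "    # Hypothetical stand-in for Relay's TensorType: a shape plus a dtype.\n",
    "    shape: tuple\n",
    "    dtype: str\n",
    "\n",
    "\n",
    "def simulated_quantize_rel(types):\n",
    "    # Toy model of SimulatedQuantizeRel; fills in `types` in place.\n",
    "    assert len(types) == 5  # 4 inputs + 1 output\n",
    "    data = types[0]\n",
    "    if not isinstance(data, TensorType):\n",
    "        return False  # type of `data` is not resolved yet\n",
    "    assert len(data.shape) != 0, 'Input shape cannot be empty'\n",
    "    scalar = TensorType((), 'float32')\n",
    "    types[1] = scalar  # dom_scale\n",
    "    types[2] = scalar  # clip_min\n",
    "    types[3] = scalar  # clip_max\n",
    "    types[4] = data    # output has the same type as the input\n",
    "    return True\n",
    "\n",
    "\n",
    "types = [TensorType((1, 3, 224, 224), 'float32'), None, None, None, None]\n",
    "assert simulated_quantize_rel(types)\n",
    "assert types[4] == types[0]\n",
    "assert types[1].shape == ()\n",
    "print(types[4])"
   ]
  },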
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "::::{dropdown}\n",
    "```c++\n",
    "RELAY_REGISTER_OP(\"relay.op.annotation.simulated_quantize\")\n",
    "    .describe(R\"code(simulated quantize op)code\" TVM_ADD_FILELINE)\n",
    "    .set_num_inputs(4)\n",
    "    .add_argument(\"data\", \"Tensor\", \"The input data.\")\n",
    "    .add_argument(\"dom_scale\", \"Tensor\", \"The domain scale of input data. It should be a scalar\")\n",
    "    .add_argument(\"clip_min\", \"Tensor\", \"lower bound. It should be a scalar\")\n",
    "    .add_argument(\"clip_max\", \"Tensor\", \"upper bound. It should be a scalar\")\n",
    "    .set_attrs_type<SimulatedQuantizeAttrs>()\n",
    "    .set_support_level(11)\n",
    "    .add_type_rel(\"SimulatedQuantize\", SimulatedQuantizeRel);\n",
    "\n",
    "TVM_REGISTER_GLOBAL(\"relay._quantize.simulated_quantize\")\n",
    "    .set_body_typed([](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, int kind, bool sign,\n",
    "                       String rounding) {\n",
    "      auto attrs = make_object<SimulatedQuantizeAttrs>();\n",
    "      attrs->kind = kind;\n",
    "      attrs->sign = sign;\n",
    "      attrs->rounding = rounding;\n",
    "      static const Op& op = Op::Get(\"relay.op.annotation.simulated_quantize\");\n",
    "      return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {});\n",
    "    });\n",
    "\n",
    "```\n",
    "::::"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`RELAY_REGISTER_OP` 宏注册名为 `relay.op.annotation.simulated_quantize` 的算子，该算子有 4 个输入参数：`data`、`dom_scale`、`clip_min` 和 `clip_max`。它还设置了属性类型为 `SimulatedQuantizeAttrs`，并添加了类型关系函数 `SimulatedQuantizeRel`。\n",
    "\n",
    "`TVM_REGISTER_GLOBAL` 宏注册全局函数 `relay._quantize.simulated_quantize`，该函数接受 6 个参数：`data`、`dom_scale`、`clip_min`、`clip_max`、`kind`、`sign` 和 `rounding`。在这个函数中，首先创建 `SimulatedQuantizeAttrs` 对象，并设置其属性值。然后，调用 `relay.op.annotation.simulated_quantize` 算子，并将结果返回。"
   ]
  }
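  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hedged usage sketch, the registered global can be reached from Python through `tvm.relay.quantize._quantize`. This assumes a TVM build that still ships the Relay quantization module; the input shape and the choice of `QAnnotateKind.INPUT` are illustrative assumptions, not values taken from the source above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm import relay\n",
    "from tvm.relay.quantize import _quantize\n",
    "from tvm.relay.quantize.quantize import QAnnotateKind\n",
    "\n",
    "# Illustrative input; the three scalars are left as free variables.\n",
    "data = relay.var('data', shape=(1, 3, 4, 4), dtype='float32')\n",
    "dom_scale = relay.var('dom_scale')\n",
    "clip_min = relay.var('clip_min')\n",
    "clip_max = relay.var('clip_max')\n",
    "\n",
    "# Arguments follow the C++ lambda: four expressions, then kind, sign, rounding.\n",
    "call = _quantize.simulated_quantize(\n",
    "    data, dom_scale, clip_min, clip_max, QAnnotateKind.INPUT, True, 'round'\n",
    ")\n",
    "print(call.op.name)\n",
    "print(call.attrs.kind, call.attrs.sign, call.attrs.rounding)"
   ]
  }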
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvm-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
